{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Flash Evaluation on the DARPA E3 Theia Dataset\n",
    "\n",
    "This notebook evaluates Flash on the DARPA E3 Theia dataset. Theia is a node-level dataset, so Flash is run in its node-level setting. Because Theia lacks some essential attributes for certain node types, Flash cannot be run in its decoupled mode with offline GNN embeddings on this dataset; instead, we use an online GNN combined with word2vec semantic embeddings.\n",
    "\n",
    "## Dataset Access: \n",
    "- Access the Theia dataset via the following link: [Theia Dataset](https://drive.google.com/drive/folders/1QlbUFWAGq3Hpl8wVdzOdIoZLFxkII4EK).\n",
    "- The dataset files will be downloaded automatically by the script.\n",
    "\n",
    "## Data Parsing and Execution:\n",
    "- The script is designed to automatically parse the downloaded data files.\n",
    "- Execute all cells within this notebook to obtain the evaluation results.\n",
    "\n",
    "## Model Training and Execution Flexibility:\n",
    "- The notebook is configured to use pre-trained model weights by default.\n",
    "- Alternatively, set `Train = True` to train the Graph Neural Network (GNN) and word2vec models yourself.\n",
    "- These newly trained models can then be utilized for a comprehensive evaluation of the dataset.\n",
    "\n",
    "Follow these steps to reproduce the Flash evaluation on the Theia dataset.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "F1op-CbyLuN4",
    "tags": []
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "import torch\n",
    "from torch_geometric.data import Data\n",
    "import os\n",
    "import torch.nn.functional as F\n",
    "import json\n",
    "import warnings\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.manifold import TSNE\n",
    "warnings.filterwarnings('ignore')\n",
    "from torch_geometric.loader import NeighborLoader\n",
    "import multiprocessing\n",
    "\n",
    "# Run on the GPU when one is available; the model and graph tensors are moved to `device` later.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gdown\n",
    "\n",
    "# Google Drive share links for the raw Theia data (presumably the .tar.gz archives unpacked by run_data_processing).\n",
    "urls = [\n",
    "    \"https://drive.google.com/file/d/10cecNtR3VsHfV0N-gNEeoVeB89kCnse5/view?usp=drive_link\",\n",
    "    \"https://drive.google.com/file/d/1Kadc6CUTb4opVSDE4x6RFFnEy0P1cRp0/view?usp=drive_link\",\n",
    "]\n",
    "for share_link in urls:\n",
    "    gdown.download(share_link, fuzzy=True, use_cookies=False, quiet=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# True: train word2vec and the 20 GNN checkpoints from scratch in the cells below.\n",
    "# False: skip training and rely on the pre-trained weights under trained_weights/theia/.\n",
    "Train = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "nM7KaeCbA_mQ",
    "tags": []
   },
   "outputs": [],
   "source": [
    "from pprint import pprint\n",
    "import gzip\n",
    "from sklearn.manifold import TSNE\n",
    "import json\n",
    "import copy\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "def extract_uuid(line):\n",
    "    pattern_uuid = re.compile(r'uuid\\\":\\\"(.*?)\\\"')\n",
    "    return pattern_uuid.findall(line)\n",
    "\n",
    "def extract_subject_type(line):\n",
    "    pattern_type = re.compile(r'type\\\":\\\"(.*?)\\\"')\n",
    "    return pattern_type.findall(line)\n",
    "\n",
    "def show(file_path):\n",
    "    print(f\"Processing {file_path}\")\n",
    "\n",
    "def extract_edge_info(line):\n",
    "    pattern_src = re.compile(r'subject\\\":{\\\"com.bbn.tc.schema.avro.cdm18.UUID\\\":\\\"(.*?)\\\"}')\n",
    "    pattern_dst1 = re.compile(r'predicateObject\\\":{\\\"com.bbn.tc.schema.avro.cdm18.UUID\\\":\\\"(.*?)\\\"}')\n",
    "    pattern_dst2 = re.compile(r'predicateObject2\\\":{\\\"com.bbn.tc.schema.avro.cdm18.UUID\\\":\\\"(.*?)\\\"}')\n",
    "    pattern_type = re.compile(r'type\\\":\\\"(.*?)\\\"')\n",
    "    pattern_time = re.compile(r'timestampNanos\\\":(.*?),')\n",
    "\n",
    "    edge_type = extract_subject_type(line)[0]\n",
    "    timestamp = pattern_time.findall(line)[0]\n",
    "    src_id = pattern_src.findall(line)\n",
    "\n",
    "    if len(src_id) == 0:\n",
    "        return None, None, None, None, None\n",
    "\n",
    "    src_id = src_id[0]\n",
    "    dst_id1 = pattern_dst1.findall(line)\n",
    "    dst_id2 = pattern_dst2.findall(line)\n",
    "\n",
    "    if len(dst_id1) > 0 and dst_id1[0] != 'null':\n",
    "        dst_id1 = dst_id1[0]\n",
    "    else:\n",
    "        dst_id1 = None\n",
    "\n",
    "    if len(dst_id2) > 0 and dst_id2[0] != 'null':\n",
    "        dst_id2 = dst_id2[0]\n",
    "    else:\n",
    "        dst_id2 = None\n",
    "\n",
    "    return src_id, edge_type, timestamp, dst_id1, dst_id2\n",
    "\n",
    "def process_data(file_path):\n",
    "    \"\"\"Map every node uuid in a CDM JSON file (and its numbered shards) to its node type.\n",
    "\n",
    "    Reads `file_path`, then `file_path.1`, `file_path.2`, ... up to `.99`, stopping at the\n",
    "    first shard that does not exist. Event, Host, TimeMarker, StartMarker, EndMarker and\n",
    "    UnitDependency records are skipped. A record without any `type` field is typed by its\n",
    "    CDM record name when it is a MemoryObject, NetFlowObject or UnnamedPipeObject;\n",
    "    otherwise the first `type` value is used (IndexError if there is none). Prints a\n",
    "    running line count every 1,000,000 lines of a shard.\n",
    "    \"\"\"\n",
    "    skip_markers = (\n",
    "        'com.bbn.tc.schema.avro.cdm18.Event',\n",
    "        'com.bbn.tc.schema.avro.cdm18.Host',\n",
    "        'com.bbn.tc.schema.avro.cdm18.TimeMarker',\n",
    "        'com.bbn.tc.schema.avro.cdm18.StartMarker',\n",
    "        'com.bbn.tc.schema.avro.cdm18.UnitDependency',\n",
    "        'com.bbn.tc.schema.avro.cdm18.EndMarker',\n",
    "    )\n",
    "    # Record kinds that carry no `type` field; checked in this order.\n",
    "    untyped_kinds = ('MemoryObject', 'NetFlowObject', 'UnnamedPipeObject')\n",
    "\n",
    "    id_nodetype_map = {}\n",
    "    for shard in range(100):\n",
    "        shard_path = file_path if shard == 0 else file_path + '.' + str(shard)\n",
    "        if not os.path.exists(shard_path):\n",
    "            break\n",
    "\n",
    "        with open(shard_path, 'r') as reader:\n",
    "            show(shard_path)\n",
    "            for line_no, line in enumerate(reader, start=1):\n",
    "                if line_no % 1000000 == 0:\n",
    "                    print(line_no)\n",
    "                if any(marker in line for marker in skip_markers):\n",
    "                    continue\n",
    "\n",
    "                uuid = extract_uuid(line)[0]\n",
    "                node_types = extract_subject_type(line)\n",
    "                if not node_types:\n",
    "                    kind = next((k for k in untyped_kinds\n",
    "                                 if 'com.bbn.tc.schema.avro.cdm18.' + k in line), None)\n",
    "                    if kind is not None:\n",
    "                        id_nodetype_map[uuid] = kind\n",
    "                        continue\n",
    "                id_nodetype_map[uuid] = node_types[0]\n",
    "\n",
    "    return id_nodetype_map\n",
    "\n",
    "def process_edges(file_path, id_nodetype_map):\n",
    "    \"\"\"Write a tab-separated edge list for every shard of a CDM JSON file.\n",
    "\n",
    "    Shards are `file_path`, `file_path.1`, ... up to `.99`, stopping at the first missing\n",
    "    one; shard S is written to `S + '.txt'`, one line per edge:\n",
    "    src_id, src_type, dst_id, dst_type, edge_type, timestamp.\n",
    "    An Event yields an edge to predicateObject and/or predicateObject2 whenever that node\n",
    "    is present in `id_nodetype_map`. Prints a running line count every 1,000,000 lines.\n",
    "\n",
    "    Returns the number of Event lines skipped because their subject was missing or not in\n",
    "    `id_nodetype_map`.\n",
    "    \"\"\"\n",
    "    notice_num = 1000000\n",
    "    not_in_cnt = 0\n",
    "\n",
    "    for i in range(100):\n",
    "        now_path = file_path if i == 0 else file_path + '.' + str(i)\n",
    "        if not os.path.exists(now_path):\n",
    "            break\n",
    "\n",
    "        with open(now_path, 'r') as f, open(now_path + '.txt', 'w') as fw:\n",
    "            for cnt, line in enumerate(f, start=1):\n",
    "                if cnt % notice_num == 0:\n",
    "                    print(cnt)\n",
    "                if 'com.bbn.tc.schema.avro.cdm18.Event' not in line:\n",
    "                    continue\n",
    "\n",
    "                src_id, edge_type, timestamp, dst_id1, dst_id2 = extract_edge_info(line)\n",
    "                if src_id is None or src_id not in id_nodetype_map:\n",
    "                    not_in_cnt += 1\n",
    "                    continue\n",
    "\n",
    "                src_type = id_nodetype_map[src_id]\n",
    "                # Both destinations are handled identically; unknown ones are dropped.\n",
    "                for dst_id in (dst_id1, dst_id2):\n",
    "                    if dst_id is not None and dst_id in id_nodetype_map:\n",
    "                        fw.write(f\"{src_id}\\t{src_type}\\t{dst_id}\\t{id_nodetype_map[dst_id]}\\t{edge_type}\\t{timestamp}\\n\")\n",
    "\n",
    "    return not_in_cnt\n",
    "\n",
    "def run_data_processing():\n",
    "    \"\"\"Unpack the Theia archives, convert them to edge lists and stage the train/test files.\n",
    "\n",
    "    Produces `theia_train.txt` (from ta1-theia-e3-official-1r.json) and `theia_test.txt`\n",
    "    (from shard .8 of ta1-theia-e3-official-6r.json). Raises subprocess.CalledProcessError\n",
    "    if `tar` fails and FileNotFoundError if an expected edge file was not produced.\n",
    "    \"\"\"\n",
    "    import shutil\n",
    "    import subprocess\n",
    "\n",
    "    archives = ['ta1-theia-e3-official-1r.json.tar.gz', 'ta1-theia-e3-official-6r.json.tar.gz']\n",
    "    for archive in archives:\n",
    "        # check=True stops here on a failed extraction instead of parsing missing files.\n",
    "        subprocess.run(['tar', '-zxvf', archive], check=True)\n",
    "\n",
    "    path_list = ['ta1-theia-e3-official-1r.json', 'ta1-theia-e3-official-6r.json']\n",
    "    for path in path_list:\n",
    "        id_nodetype_map = process_data(path)\n",
    "        process_edges(path, id_nodetype_map)\n",
    "\n",
    "    shutil.copyfile('ta1-theia-e3-official-1r.json.txt', 'theia_train.txt')\n",
    "    shutil.copyfile('ta1-theia-e3-official-6r.json.8.txt', 'theia_test.txt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parsing the raw CDM files is slow; skip it when every file the later cells read already exists.\n",
    "STAGED_FILES = ['theia_train.txt', 'theia_test.txt',\n",
    "                'ta1-theia-e3-official-1r.json', 'ta1-theia-e3-official-6r.json.8']\n",
    "if not all(os.path.exists(p) for p in STAGED_FILES):\n",
    "    run_data_processing()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def add_node_properties(nodes, node_id, properties):\n",
    "    if node_id not in nodes:\n",
    "        nodes[node_id] = []\n",
    "    nodes[node_id].extend(properties)\n",
    "\n",
    "def update_edge_index(edges, edge_index, index):\n",
    "    for src_id, dst_id in edges:\n",
    "        src = index[src_id]\n",
    "        dst = index[dst_id]\n",
    "        edge_index[0].append(src)\n",
    "        edge_index[1].append(dst)\n",
    "\n",
    "def prepare_graph(df):\n",
    "    nodes, labels, edges = {}, {}, []\n",
    "    dummies = {\"SUBJECT_PROCESS\":\t0, \"MemoryObject\":\t1, \"FILE_OBJECT_BLOCK\":\t2,\n",
    "               \"NetFlowObject\":\t3,\"PRINCIPAL_REMOTE\":\t4,'PRINCIPAL_LOCAL':5}\n",
    "\n",
    "    for _, row in df.iterrows():\n",
    "        action = row[\"action\"]\n",
    "        properties = [row['exec'], action] + ([row['path']] if row['path'] else [])\n",
    "        \n",
    "        actor_id = row[\"actorID\"]\n",
    "        add_node_properties(nodes, actor_id, properties)\n",
    "        labels[actor_id] = dummies[row['actor_type']]\n",
    "\n",
    "        object_id = row[\"objectID\"]\n",
    "        add_node_properties(nodes, object_id, properties)\n",
    "        labels[object_id] = dummies[row['object']]\n",
    "\n",
    "        edges.append((actor_id, object_id))\n",
    "\n",
    "    features, feat_labels, edge_index, index_map = [], [], [[], []], {}\n",
    "    for node_id, props in nodes.items():\n",
    "        features.append(props)\n",
    "        feat_labels.append(labels[node_id])\n",
    "        index_map[node_id] = len(features) - 1\n",
    "\n",
    "    update_edge_index(edges, edge_index, index_map)\n",
    "\n",
    "    return features, feat_labels, edge_index, list(index_map.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fmXWs1dKIzD8",
    "tags": []
   },
   "outputs": [],
   "source": [
    "from torch_geometric.nn import GCNConv\n",
    "from torch_geometric.nn import SAGEConv, GATConv\n",
    "import torch.nn as nn\n",
    "\n",
    "class GCN(torch.nn.Module):\n",
    "    \"\"\"Two-layer GraphSAGE node classifier (despite the name, it uses SAGEConv, not GCNConv).\n",
    "\n",
    "    in_channel -> 32 hidden units (ReLU, dropout p=0.5 in training mode) -> out_channel\n",
    "    class scores. The attribute names conv1/conv2 are the keys of the saved state_dicts.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,in_channel,out_channel):\n",
    "        super().__init__()\n",
    "        hidden = 32\n",
    "        self.conv1 = SAGEConv(in_channel, hidden, normalize=True)\n",
    "        self.conv2 = SAGEConv(hidden, out_channel, normalize=True)\n",
    "\n",
    "    def forward(self, x, edge_index):\n",
    "        h = F.dropout(self.conv1(x, edge_index).relu(), p=0.5, training=self.training)\n",
    "        return self.conv2(h, edge_index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "YBuP_tSq94f4",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def visualize(h, color):\n",
    "    \"\"\"Scatter a 2-D t-SNE projection of the tensor `h`, coloured by `color`.\"\"\"\n",
    "    points = TSNE(n_components=2).fit_transform(h.detach().cpu().numpy())\n",
    "    xs, ys = points[:, 0], points[:, 1]\n",
    "\n",
    "    plt.figure(figsize=(10, 10))\n",
    "    plt.xticks([])\n",
    "    plt.yticks([])\n",
    "    plt.scatter(xs, ys, s=70, c=color, cmap=\"Set2\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "3PCP6SXwZaif",
    "tags": []
   },
   "outputs": [],
   "source": [
    "from gensim.models.callbacks import CallbackAny2Vec\n",
    "import gensim\n",
    "from gensim.models import Word2Vec\n",
    "from multiprocessing import Pool\n",
    "from itertools import compress\n",
    "from tqdm import tqdm\n",
    "import time\n",
    "\n",
    "class EpochSaver(CallbackAny2Vec):\n",
    "    \"\"\"gensim callback that checkpoints the model to word2vec_theia_E3.model after every epoch.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.epoch = 0  # number of completed epochs\n",
    "\n",
    "    def on_epoch_end(self, model):\n",
    "        model.save('word2vec_theia_E3.model')\n",
    "        self.epoch = self.epoch + 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "P8oBL8LFaeOf",
    "tags": []
   },
   "outputs": [],
   "source": [
    "class EpochLogger(CallbackAny2Vec):\n",
    "    \"\"\"gensim callback that prints a line at the start and at the end of every epoch.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.epoch = 0  # index of the epoch currently running\n",
    "\n",
    "    def on_epoch_begin(self, model):\n",
    "        print(f\"Epoch #{self.epoch} start\")\n",
    "\n",
    "    def on_epoch_end(self, model):\n",
    "        print(f\"Epoch #{self.epoch} end\")\n",
    "        self.epoch += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Se7Ei4tAapVj",
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Word2Vec training callbacks: `saver` checkpoints every epoch, `logger` prints progress.\n",
    "saver = EpochSaver()\n",
    "logger = EpochLogger()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "3RDmGME5iPb5",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def add_attributes(d,p):\n",
    "    \n",
    "    f = open(p)\n",
    "    data = [json.loads(x) for x in f if \"EVENT\" in x]\n",
    "\n",
    "    info = []\n",
    "    for x in data:\n",
    "        try:\n",
    "            action = x['datum']['com.bbn.tc.schema.avro.cdm18.Event']['type']\n",
    "        except:\n",
    "            action = ''\n",
    "        try:\n",
    "            actor = x['datum']['com.bbn.tc.schema.avro.cdm18.Event']['subject']['com.bbn.tc.schema.avro.cdm18.UUID']\n",
    "        except:\n",
    "            actor = ''\n",
    "        try:\n",
    "            obj = x['datum']['com.bbn.tc.schema.avro.cdm18.Event']['predicateObject']['com.bbn.tc.schema.avro.cdm18.UUID']\n",
    "        except:\n",
    "            obj = ''\n",
    "        try:\n",
    "            timestamp = x['datum']['com.bbn.tc.schema.avro.cdm18.Event']['timestampNanos']\n",
    "        except:\n",
    "            timestamp = ''\n",
    "        try:\n",
    "            cmd = x['datum']['com.bbn.tc.schema.avro.cdm18.Event']['properties']['map']['cmdLine']\n",
    "        except:\n",
    "            cmd = ''\n",
    "        try:\n",
    "            path = x['datum']['com.bbn.tc.schema.avro.cdm18.Event']['predicateObjectPath']['string']\n",
    "        except:\n",
    "            path = ''\n",
    "        try:\n",
    "            path2 = x['datum']['com.bbn.tc.schema.avro.cdm18.Event']['predicateObject2Path']['string']\n",
    "        except:\n",
    "            path2 = ''\n",
    "        try:\n",
    "            obj2 = x['datum']['com.bbn.tc.schema.avro.cdm18.Event']['predicateObject2']['com.bbn.tc.schema.avro.cdm18.UUID']\n",
    "            info.append({'actorID':actor,'objectID':obj2,'action':action,'timestamp':timestamp,'exec':cmd, 'path':path2})\n",
    "        except:\n",
    "            pass\n",
    "\n",
    "        info.append({'actorID':actor,'objectID':obj,'action':action,'timestamp':timestamp,'exec':cmd, 'path':path})\n",
    "\n",
    "    rdf = pd.DataFrame.from_records(info).astype(str)\n",
    "    d = d.astype(str)\n",
    "\n",
    "    return d.merge(rdf,how='inner',on=['actorID','objectID','action','timestamp']).drop_duplicates()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "if Train:\n",
    "    # Load the tab-separated training edge list produced by run_data_processing;\n",
    "    # `with` guarantees the file handle is closed.\n",
    "    with open(\"theia_train.txt\") as f:\n",
    "        data = f.read().split('\\n')\n",
    "    data = [line.split('\\t') for line in data]\n",
    "    df = pd.DataFrame (data, columns = ['actorID', 'actor_type','objectID','object','action','timestamp'])\n",
    "    # The trailing empty line parses to a single-field row (other columns None); drop it.\n",
    "    df = df.dropna()\n",
    "    df.sort_values(by='timestamp', ascending=True,inplace=True)\n",
    "    df = add_attributes(df,\"ta1-theia-e3-official-1r.json\")\n",
    "    phrases,labels,edges,mapp = prepare_graph(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "p3TAi69zI1bO",
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sklearn.utils import class_weight\n",
    "import torch.nn.functional as F\n",
    "from torch.nn import CrossEntropyLoss\n",
    "\n",
    "# 30 input features = the word2vec vector_size used for node embeddings (see the Word2Vec cell).\n",
    "# NOTE(review): 5 output classes, but prepare_graph's type table has 6 entries (PRINCIPAL_LOCAL = 5);\n",
    "# such a label would be out of range for CrossEntropyLoss - confirm it never occurs. The saved\n",
    "# checkpoints are loaded into this model, so any change must match them.\n",
    "model = GCN(30,5).to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "if Train:\n",
    "    # 30-dim embeddings over each node's token list; callbacks log and checkpoint every epoch.\n",
    "    word2vec = Word2Vec(\n",
    "        sentences=phrases,\n",
    "        vector_size=30,\n",
    "        window=5,\n",
    "        min_count=1,\n",
    "        workers=8,\n",
    "        epochs=300,\n",
    "        callbacks=[saver, logger],\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Vn_pMyt5Jd-6",
    "tags": []
   },
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "import numpy as np\n",
    "from gensim.models import Word2Vec\n",
    "\n",
    "class PositionalEncoder:\n",
    "\n",
    "    def __init__(self, d_model, max_len=100000):\n",
    "        position = torch.arange(max_len).unsqueeze(1)\n",
    "        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))\n",
    "        self.pe = torch.zeros(max_len, d_model)\n",
    "        self.pe[:, 0::2] = torch.sin(position * div_term)\n",
    "        self.pe[:, 1::2] = torch.cos(position * div_term)\n",
    "\n",
    "    def embed(self, x):\n",
    "        return x + self.pe[:x.size(0)]\n",
    "\n",
    "def infer(document):\n",
    "    \"\"\"Embed a node's token list as the mean of its position-encoded word2vec vectors.\n",
    "\n",
    "    Uses the module-level `w2vmodel` and `encoder`. Tokens missing from the vocabulary are\n",
    "    ignored; if none are known, a zero vector of the word2vec width is returned so every\n",
    "    node feature has the same length (a mismatched width makes np.array(nodes) ragged).\n",
    "    \"\"\"\n",
    "    word_embeddings = [w2vmodel.wv[word] for word in document if word in  w2vmodel.wv]\n",
    "\n",
    "    if not word_embeddings:\n",
    "        return np.zeros(w2vmodel.wv.vector_size)\n",
    "\n",
    "    # Stack first: building a tensor from a list of ndarrays is slow.\n",
    "    output_embedding = torch.tensor(np.array(word_embeddings), dtype=torch.float)\n",
    "    # The positional table built below has 100000 rows; only encode shorter sequences.\n",
    "    if len(document) < 100000:\n",
    "        output_embedding = encoder.embed(output_embedding)\n",
    "\n",
    "    output_embedding = output_embedding.detach().cpu().numpy()\n",
    "    return np.mean(output_embedding, axis=0)\n",
    "\n",
    "# Use the checkpoint just written by EpochSaver when training, otherwise the shipped one.\n",
    "w2v_path = \"word2vec_theia_E3.model\" if Train else \"trained_weights/theia/word2vec_theia_E3.model\"\n",
    "w2vmodel = Word2Vec.load(w2v_path)\n",
    "# The positional encoding is added to word vectors, so its width must equal vector_size.\n",
    "encoder = PositionalEncoder(w2vmodel.wv.vector_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 689309,
     "status": "ok",
     "timestamp": 1673566932746,
     "user": {
      "displayName": "Mati Ur Rehman",
      "userId": "04281203290774044297"
     },
     "user_tz": 300
    },
    "id": "Gclj6HVL17lD",
    "outputId": "f60fadea-a7db-471f-defe-fee744f6ef25",
    "tags": []
   },
   "outputs": [],
   "source": [
    "from torch_geometric import utils\n",
    "\n",
    "if Train:\n",
    "    l = np.array(labels)\n",
    "    class_weights = class_weight.compute_class_weight(class_weight = \"balanced\",classes = np.unique(l),y = l)\n",
    "    class_weights = torch.tensor(class_weights,dtype=torch.float).to(device)\n",
    "    criterion = CrossEntropyLoss(weight=class_weights,reduction='mean')\n",
    "\n",
    "    nodes = np.array([infer(x) for x in phrases])\n",
    "\n",
    "    graph = Data(x=torch.tensor(nodes,dtype=torch.float).to(device),y=torch.tensor(labels,dtype=torch.long).to(device), edge_index=torch.tensor(edges,dtype=torch.long).to(device))\n",
    "    graph.n_id = torch.arange(graph.num_nodes)\n",
    "    # Nodes still to be learned: each round trains only on nodes that earlier rounds\n",
    "    # neither classified correctly nor with normalised confidence >= 0.9.\n",
    "    mask = torch.tensor([True]*graph.num_nodes, dtype=torch.bool)\n",
    "\n",
    "    for m_n in range(20):\n",
    "        loader = NeighborLoader(graph, num_neighbors=[-1,-1], batch_size=5000,input_nodes=mask)\n",
    "        total_loss = 0\n",
    "        model.train()\n",
    "        for subg in loader:\n",
    "            optimizer.zero_grad()\n",
    "            out = model(subg.x, subg.edge_index)\n",
    "            loss = criterion(out, subg.y)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            total_loss += loss.item() * subg.batch_size\n",
    "        print(total_loss / mask.sum().item())\n",
    "\n",
    "        # Pruning pass: inference only, so skip autograd bookkeeping.\n",
    "        loader = NeighborLoader(graph, num_neighbors=[-1,-1], batch_size=5000,input_nodes=mask)\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            for subg in loader:\n",
    "                out = model(subg.x, subg.edge_index)\n",
    "\n",
    "                # `top`, not `sorted`: a global `sorted` would shadow the builtin notebook-wide.\n",
    "                top, indices = out.sort(dim=1,descending=True)\n",
    "                conf = (top[:,0] - top[:,1]) / top[:,0]\n",
    "                # NOTE(review): normalises by max() rather than (max - min); kept identical to\n",
    "                # the evaluation cell so both use the same confidence scale.\n",
    "                conf = (conf - conf.min()) / conf.max()\n",
    "\n",
    "                pred = indices[:,0]\n",
    "                cond = (pred == subg.y) | (conf >= 0.9)\n",
    "                # Align devices: `cond` follows the model's device, `mask` was created on the CPU.\n",
    "                mask[subg.n_id[cond.to(subg.n_id.device)].to(mask.device)] = False\n",
    "\n",
    "        torch.save(model.state_dict(), f'lword2vec_gnn_theia{m_n}_E3.pth')\n",
    "        print(f'Model# {m_n}. {mask.sum().item()} nodes still misclassified \\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vgqyu7E5qPet",
    "tags": []
   },
   "outputs": [],
   "source": [
    "from itertools import compress\n",
    "from torch_geometric import utils\n",
    "\n",
    "def Get_Adjacent(ids, mapp, edges, hops):\n",
    "    if hops == 0:\n",
    "        return set()\n",
    "    \n",
    "    neighbors = set()\n",
    "    for edge in zip(edges[0], edges[1]):\n",
    "        if any(mapp[node] in ids for node in edge):\n",
    "            neighbors.update(mapp[node] for node in edge)\n",
    "\n",
    "    if hops > 1:\n",
    "        neighbors = neighbors.union(Get_Adjacent(neighbors, mapp, edges, hops - 1))\n",
    "    \n",
    "    return neighbors\n",
    "\n",
    "def calculate_metrics(TP, FP, FN, TN):\n",
    "    FPR = FP / (FP + TN) if FP + TN > 0 else 0\n",
    "    TPR = TP / (TP + FN) if TP + FN > 0 else 0\n",
    "\n",
    "    prec = TP / (TP + FP) if TP + FP > 0 else 0\n",
    "    rec = TP / (TP + FN) if TP + FN > 0 else 0\n",
    "    fscore = (2 * prec * rec) / (prec + rec) if prec + rec > 0 else 0\n",
    "\n",
    "    return prec, rec, fscore, FPR, TPR\n",
    "\n",
    "def helper(MP, all_pids, GP, edges, mapp):\n",
    "    \"\"\"Score flagged nodes `MP` against ground-truth nodes `GP`, print the metrics, return (TPL, FPL).\n",
    "\n",
    "    Uses a two-hop tolerance: a false positive within two hops of any ground-truth node is\n",
    "    not counted, and a missed ground-truth node within two hops of a true positive is\n",
    "    credited as detected. `all_pids` is the universe of node ids (for true negatives).\n",
    "    \"\"\"\n",
    "    hits = MP.intersection(GP)\n",
    "    false_alarms = MP - GP\n",
    "    misses = GP - MP\n",
    "    quiet = all_pids - (GP | MP)\n",
    "\n",
    "    near_truth = Get_Adjacent(GP, mapp, edges, 2)\n",
    "    near_hits = Get_Adjacent(hits, mapp, edges, 2)\n",
    "    FPL = false_alarms - near_truth\n",
    "    credited = misses.intersection(near_hits)\n",
    "    TPL = hits.union(credited)\n",
    "    misses = misses - near_hits\n",
    "\n",
    "    TP, FP, FN, TN = len(TPL), len(FPL), len(misses), len(quiet)\n",
    "    prec, rec, fscore, FPR, TPR = calculate_metrics(TP, FP, FN, TN)\n",
    "    print(f\"True Positives: {TP}, False Positives: {FP}, False Negatives: {FN}\")\n",
    "    print(f\"Precision: {round(prec, 2)}, Recall: {round(rec, 2)}, Fscore: {round(fscore, 2)}\")\n",
    "\n",
    "    return TPL, FPL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "OZFrSLVZ29qU",
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Load the tab-separated test edge list produced by run_data_processing (closed via `with`).\n",
    "with open(\"theia_test.txt\") as f:\n",
    "    data = f.read().split('\\n')\n",
    "data = [line.split('\\t') for line in data]\n",
    "df = pd.DataFrame (data, columns = ['actorID', 'actor_type','objectID','object','action','timestamp'])\n",
    "# The trailing empty line parses to a single-field row (other columns None); drop it.\n",
    "df = df.dropna()\n",
    "df.sort_values(by='timestamp', ascending=True,inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Same CDM shard that theia_test.txt was built from (see run_data_processing).\n",
    "df = add_attributes(d=df, p=\"ta1-theia-e3-official-6r.json.8\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "with open(\"data_files/theia.json\", \"r\") as json_file:\n",
    "    # Ground-truth node ids; compared against the node ids in `mapp` by helper().\n",
    "    GT_mal = set(json.load(json_file))\n",
    "\n",
    "data = df\n",
    "\n",
    "phrases,labels,edges,mapp = prepare_graph(data)\n",
    "# One embedding row per node, in the same order as `mapp`.\n",
    "nodes = np.array([infer(x) for x in phrases])\n",
    "\n",
    "all_ids = set(data['actorID']) | set(data['objectID'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 15919,
     "status": "ok",
     "timestamp": 1673572401286,
     "user": {
      "displayName": "Mati Ur Rehman",
      "userId": "04281203290774044297"
     },
     "user_tz": 300
    },
    "id": "DsLlVS6zpox5",
    "outputId": "93d85078-5348-4225-e179-162b682ec426",
    "tags": []
   },
   "outputs": [],
   "source": [
    "graph = Data(x=torch.tensor(nodes,dtype=torch.float).to(device),y=torch.tensor(labels,dtype=torch.long).to(device), edge_index=torch.tensor(edges,dtype=torch.long).to(device))\n",
    "graph.n_id = torch.arange(graph.num_nodes)\n",
    "# A node stays flagged until some checkpoint predicts its type correctly with normalised\n",
    "# confidence above CONF_THRESHOLD; nodes still flagged at the end are scored as alerts.\n",
    "flag = torch.tensor([True]*graph.num_nodes, dtype=torch.bool)\n",
    "CONF_THRESHOLD = 0.53\n",
    "# Evaluate the checkpoints written by the training cell when Train is set, else the shipped ones.\n",
    "weights_dir = '' if Train else 'trained_weights/theia/'\n",
    "\n",
    "model.eval()\n",
    "with torch.no_grad():  # inference only: no autograd graphs kept in memory\n",
    "    for m_n in range(20):\n",
    "        model.load_state_dict(torch.load(f'{weights_dir}lword2vec_gnn_theia{m_n}_E3.pth',map_location=torch.device('cpu')))\n",
    "        loader = NeighborLoader(graph, num_neighbors=[-1,-1], batch_size=5000)\n",
    "        for subg in loader:\n",
    "            out = model(subg.x, subg.edge_index)\n",
    "\n",
    "            # `top`, not `sorted`: a global `sorted` would shadow the builtin notebook-wide.\n",
    "            top, indices = out.sort(dim=1,descending=True)\n",
    "            conf = (top[:,0] - top[:,1]) / top[:,0]\n",
    "            # NOTE(review): normalises by max() rather than (max - min); kept as-is because\n",
    "            # CONF_THRESHOLD is defined on this scale.\n",
    "            conf = (conf - conf.min()) / conf.max()\n",
    "\n",
    "            pred = indices[:,0]\n",
    "            cond = (pred == subg.y) & (conf > CONF_THRESHOLD)\n",
    "            # Clear the flag of every confidently-correct node (device-aligned indexing).\n",
    "            flag[subg.n_id[cond.to(subg.n_id.device)].to(flag.device)] = False\n",
    "\n",
    "index = utils.mask_to_index(flag).tolist()\n",
    "ids = {mapp[x] for x in index}\n",
    "alerts = helper(set(ids),set(all_ids),GT_mal,edges,mapp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def traverse(ids, mapping, edges, hops, visited=None):\n",
    "    if hops == 0:\n",
    "        return set()\n",
    "\n",
    "    if visited is None:\n",
    "        visited = set()\n",
    "\n",
    "    neighbors = set()\n",
    "    for src, dst in zip(edges[0], edges[1]):\n",
    "        src_mapped, dst_mapped = mapping[src], mapping[dst]\n",
    "\n",
    "        if (src_mapped in ids and dst_mapped not in visited) or \\\n",
    "           (dst_mapped in ids and src_mapped not in visited):\n",
    "            neighbors.add(src_mapped)\n",
    "            neighbors.add(dst_mapped)\n",
    "\n",
    "        visited.add(src_mapped)\n",
    "        visited.add(dst_mapped)\n",
    "\n",
    "    neighbors.difference_update(ids) \n",
    "    return ids.union(traverse(neighbors, mapping, edges, hops - 1, visited))\n",
    "\n",
    "def load_data(file_path):\n",
    "    with open(file_path, 'r') as file:\n",
    "        return json.load(file)\n",
    "\n",
    "def find_connected_alerts(start_alert, mapping, edges, depth, remaining_alerts):\n",
    "    \"\"\"Return the members of `remaining_alerts` that traverse() reaches from `start_alert`.\"\"\"\n",
    "    reachable = traverse({start_alert}, mapping, edges, depth)\n",
    "    return {node for node in reachable if node in remaining_alerts}\n",
    "\n",
    "def generate_incident_graphs(alerts, edges, mapping, depth):\n",
    "    \"\"\"Group alerts into incidents of alerts connected within the traversal depth.\n",
    "\n",
    "    Repeatedly takes an unprocessed alert and collects the other unprocessed alerts that\n",
    "    traverse(..., depth) reaches from it. If more than one such alert is found, the seed\n",
    "    together with those alerts forms one incident, and the connected alerts are removed\n",
    "    from further consideration. Returns a list of sets of alert ids.\n",
    "    \"\"\"\n",
    "    incident_graphs = []\n",
    "    remaining_alerts = set(alerts)\n",
    "\n",
    "    while remaining_alerts:\n",
    "        alert = remaining_alerts.pop()\n",
    "        connected_alerts = find_connected_alerts(alert, mapping, edges, depth, remaining_alerts)\n",
    "\n",
    "        if len(connected_alerts) > 1:\n",
    "            # The seed was popped above, so it is never in connected_alerts; add it back so\n",
    "            # the incident includes the alert it was grown from.\n",
    "            incident_graphs.append(connected_alerts | {alert})\n",
    "            remaining_alerts -= connected_alerts\n",
    "\n",
    "    return incident_graphs\n"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "PyTorch 1.12.0",
   "language": "python",
   "name": "pytorch-1.12.0"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
